import h5py
import json
import argparse
import os
from shutil import copyfile
import robosuite
import xml.etree.ElementTree as ET

import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.file_utils as FileUtils

from robosuite.utils.mjcf_utils import find_elements

def replace_elem(parent, old_elem, new_elem):
    """
    Put ``new_elem`` into ``parent`` in place of its child ``old_elem``.

    The replacement occupies the same child index that ``old_elem`` had, so the
    order of the remaining siblings is unchanged. Raises ``ValueError`` (and
    leaves ``parent`` untouched) if ``old_elem`` is not a child of ``parent``.

    code adapted from https://stackoverflow.com/a/20931505
    """
    slot = list(parent).index(old_elem)
    # item assignment on an Element swaps the child at that index in one step
    parent[slot] = new_elem

def _swap_bodies(old_xml_str, new_xml_str, body_names):
    """
    Copy selected ``<body>`` elements from an old xml string into a new one.

    Each body named in ``body_names`` is looked up in both documents. The new
    document's body is then replaced by the old document's body, at the same
    position under the same parent.

    Args:
        old_xml_str (str): xml string to take the bodies from
        new_xml_str (str): xml string whose bodies are replaced
        body_names (iterable of str): names of the bodies to transplant

    Returns:
        str: the modified new xml string

    Raises:
        ValueError: if a requested body is missing from either xml string
    """
    old_root = ET.fromstring(old_xml_str)
    new_root = ET.fromstring(new_xml_str)

    for bname in body_names:
        body_new = find_elements(
            root=new_root,
            tags="body",
            attribs={"name": bname},
            return_first=True
        )
        body_old = find_elements(
            root=old_root,
            tags="body",
            attribs={"name": bname},
            return_first=True
        )
        if body_new is None or body_old is None:
            raise ValueError(
                "body {!r} not found in {} xml".format(
                    bname, "new" if body_new is None else "old"
                )
            )

        # Look up the real parent instead of assuming a direct child of
        # <worldbody>. The map is rebuilt every iteration because a previous
        # swap may have changed the tree.
        parents = {child: parent for parent in new_root.iter() for child in parent}
        replace_elem(parents[body_new], body_new, body_old)

    return ET.tostring(new_root, encoding="utf8").decode("utf8")


def convert_xml(old_xml_str, env_name, env):
    """
    Postprocess an xml string generated by robosuite v1.2 so that it can be used
    with the installed robosuite (the script below requires v1.4.1).

    For PickPlaceCan, NutAssemblySquare and ToolHang, the xml is simply
    regenerated from ``env``. For Lift and TwoArmTransport, the xml is also
    regenerated, but selected bodies are copied back from ``old_xml_str`` so
    they keep their old definition.

    Args:
        old_xml_str (str): xml string to process (from robosuite v1.2)
        env_name (str): name of the environment the xml belongs to
        env: robomimic environment wrapper; ``env.env.sim.model.get_xml()``
            provides the xml from the installed robosuite

    Returns:
        str: converted xml string

    Raises:
        ValueError: if ``env_name`` is not supported, or a body to transplant
            is missing from either xml string
    """
    if env_name in ("PickPlaceCan", "NutAssemblySquare", "ToolHang"):
        return env.env.sim.model.get_xml()

    if env_name == "Lift":
        # keep the cube body (and all of its geoms) from the old xml
        body_names = ["cube_main"]
    elif env_name == "TwoArmTransport":
        body_names = [
            "payload_root",

            ### ignore all these other following assets (makes playback worse for some reason...)
            # "trash_main",
            # "transport_start_bin_root", "transport_target_bin_root",
            # "transport_trash_bin_root", "transport_start_bin_lid_root"
        ]
    else:
        raise ValueError(
            "Unsupported env_name {!r}: don't know how to convert its model xml".format(env_name)
        )

    return _swap_bodies(old_xml_str, env.env.sim.model.get_xml(), body_names)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="path to input hdf5 dataset",
    )
    parser.add_argument(
        "--output_dataset",
        type=str,
        required=True,
        help="path to output hdf5 dataset",
    )
    args = parser.parse_args()

    args.dataset = os.path.expanduser(args.dataset)
    args.output_dataset = os.path.expanduser(args.output_dataset)

    # Explicit checks instead of `assert`, so they still run under `python -O`.
    # The input is copied and the copy is edited in place, so the paths must differ.
    if os.path.realpath(args.output_dataset) == os.path.realpath(args.dataset):
        parser.error("--output_dataset must be different from --dataset")
    if robosuite.__version__ != "1.4.1":
        raise RuntimeError(
            "this script requires robosuite 1.4.1, found {}".format(robosuite.__version__)
        )

    copyfile(args.dataset, args.output_dataset)

    # need to make sure ObsUtils knows which observations are images, but it doesn't matter
    # for playback since observations are unused. Pass a dummy spec here.
    dummy_spec = dict(
        obs=dict(
            low_dim=["robot0_eef_pos"],
            rgb=[],
        ),
    )
    ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)

    # the environment built from the dataset metadata supplies the new-version xml
    env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset)
    env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=False, render_offscreen=True)
    env.reset()

    # context manager guarantees the file is closed even if a conversion fails
    with h5py.File(args.output_dataset, "r+") as f:
        data_grp = f["data"]
        env_args = json.loads(data_grp.attrs["env_args"])
        env_name = env_args["env_name"]

        for demo_key in data_grp:
            ep_data_grp = data_grp[demo_key]
            converted_model_file = convert_xml(ep_data_grp.attrs["model_file"], env_name, env)
            ep_data_grp.attrs["model_file"] = converted_model_file

        # record the robosuite version the converted xml now targets
        env_args["env_version"] = robosuite.__version__
        data_grp.attrs["env_args"] = json.dumps(env_args, indent=4)
